Trying out TensorFlow StochasticTensors
Write your post here.
Thrilled to be attending #icml2017 right here at home in Sydney! Excited to meet and learn from a lot of brilliant minds this week. pic.twitter.com/jazfvgabCq
— Louis Tiao (@louistiao) August 5, 2017
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
tf.__version__
# interactive session so tensors can be .eval()'d directly in later cells
sess = tf.InteractiveSession()
# NOTE(review): hard-coded local path — point this at your own MNIST download
mnist = mnist_data.read_data_sets("/home/tiao/Desktop/MNIST")
# 50 single-channel (grayscale) 28x28 images
x = mnist.train.images[:50].reshape(-1, 28, 28, 1)
x.shape
fig, ax = plt.subplots(figsize=(5, 5))
# showing an arbitrarily chosen image
ax.imshow(np.squeeze(x[5], axis=-1), cmap='gray')
plt.show()
conv2d
# 32 kernels of size 5x5x1
kernel = tf.truncated_normal([5, 5, 1, 32], stddev=0.1)
kernel.get_shape().as_list()
# convolve the 50 images with all 32 kernels; stride 1 and 'SAME'
# zero-padding preserve the 28x28 spatial size
x_conved = tf.nn.conv2d(x, kernel,
                        strides=[1, 1, 1, 1],
                        padding='SAME')
x_conved.get_shape().as_list()
x_conved[5, ..., 0].eval().shape
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
# showing what the 0th filter looks like
ax1.imshow(kernel[..., 0, 0].eval(), cmap='gray')
# show the previous arbitrarily chosen image
# convolved with the 0th filter
ax2.imshow(x_conved[5, ..., 0].eval(), cmap='gray')
plt.show()
# 8x32 kernels of size 5x5x1 — i.e. 8 independent banks of 32 filters each
kernels = tf.truncated_normal([8, 5, 5, 1, 32], stddev=0.1)
kernels.get_shape().as_list()
conv2d
# replicate the image batch 8 times so each copy can be paired with its
# own bank of 32 kernels
x_tiled = tf.tile(tf.expand_dims(x, 0), [8, 1, 1, 1, 1])
x_tiled.get_shape().as_list()
# a single (batch, bank) pairing works with plain conv2d
tf.nn.conv2d(x_tiled[0], kernels[0],
             strides=[1, 1, 1, 1],
             padding='SAME').get_shape().as_list()
# approach 1: map conv2d over the leading (bank) axis
x_conved1 = tf.map_fn(lambda args: tf.nn.conv2d(*args, strides=[1, 1, 1, 1], padding='SAME'),
                      elems=(x_tiled, kernels), dtype=tf.float32)
x_conved1.get_shape().as_list()
# approach 2: flatten the 8 banks into one 5x5x1x(32*8) kernel, convolve
# once, then un-flatten and move the bank axis back to the front
kernels_flat = tf.reshape(tf.transpose(kernels,
                                       perm=(1, 2, 3, 4, 0)),
                          shape=(5, 5, 1, 32*8))
kernels_flat.get_shape().as_list()
x_conved2 = tf.transpose(tf.reshape(tf.nn.conv2d(x, kernels_flat,
                                                 strides=[1, 1, 1, 1],
                                                 padding='SAME'),
                                    shape=(50, 28, 28, 32, 8)),
                         perm=(4, 0, 1, 2, 3))
x_conved2.get_shape().as_list()
# both approaches should produce identical results
tf.reduce_all(tf.equal(x_conved1, x_conved2)).eval()
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import keras.backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from keras.models import Model, Sequential
from keras.layers import Activation, Add, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
np.set_printoptions(precision=2,
edgeitems=3,
linewidth=80,
suppress=True)
K.tf.__version__
# problem-size and optimisation hyperparameters
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60
# 300x300 evaluation grid covering the 2-D latent space
z_min, z_max = -5, 5
z1, z2 = np.mgrid[z_min:z_max:300j, z_min:z_max:300j]
z_grid = np.dstack((z1, z2))
z_grid.shape
# isotropic Gaussian prior p(z) = N(0, PRIOR_VARIANCE * I)
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
                            cov=PRIOR_VARIANCE)
log_prior = prior.logpdf(z_grid)
log_prior.shape
# sanity check: closed-form 2-D Gaussian log-density matches scipy's
np.allclose(log_prior,
            -.5*np.sum(z_grid**2, axis=2)/PRIOR_VARIANCE \
            -np.log(2*np.pi*PRIOR_VARIANCE))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(z1, z2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)
plt.show()
x = np.array([0, 5, 8, 12, 50])
def log_likelihood(z, x, beta_0=3., beta_1=1.):
    """Return the log-density -log(beta) - x/beta (exponential with scale beta).

    The scale is an affine function of the rectified cube of z, reduced
    over the trailing axis: beta = beta_0 + sum(beta_1 * max(0, z**3)).
    Broadcasts over leading axes of both ``z`` and ``x``.
    """
    rectified_cubes = beta_1 * np.maximum(0, z ** 3)
    beta = beta_0 + rectified_cubes.sum(axis=-1)
    return -np.log(beta) - x / beta
llhs = log_likelihood(z_grid, x.reshape(-1, 1, 1))
llhs.shape
fig, axes = plt.subplots(ncols=len(x), nrows=1, figsize=(20, 4))
fig.tight_layout()
for i, ax in enumerate(axes):
ax.contourf(z1, z2, llhs[i,::,::], cmap=plt.cm.magma)
ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)
ax.set_title('$p(x = {{{0}}} \mid z)$'.format(x[i]))
ax.set_xlabel('$z_1$')
if not i:
ax.set_ylabel('$z_2$')
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(z1, z2, np.sum(llhs, axis=0),
cmap=plt.cm.magma)
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)
plt.show()
fig, axes = plt.subplots(ncols=len(x), nrows=1, figsize=(20, 4))
fig.tight_layout()
for i, ax in enumerate(axes):
ax.contourf(z1, z2, np.exp(log_prior+llhs[i,::,::]),
cmap='magma')
ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)
ax.set_title('$Zp(z \mid x = {{{0}}})$'.format(x[i]))
ax.set_xlabel('$z_1$')
if not i:
ax.set_ylabel('$z_2$')
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(z1, z2,
np.exp(log_prior+np.sum(llhs, axis=0)),
cmap='magma')
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)
plt.show()
$T_{\psi}(x, z)$
# discriminator T(x, z): x and z get separate dense embeddings which are
# summed, then passed through further dense layers to a single logit
x_input = Input(shape=(1,), name='x')
x_hidden = Dense(10, activation='relu')(x_input)
x_hidden = Dense(20, activation='relu')(x_hidden)
z_input = Input(shape=(LATENT_DIM,), name='z')
z_hidden = Dense(10, activation='relu')(z_input)
z_hidden = Dense(20, activation='relu')(z_hidden)
discrim_hidden = Add()([x_hidden, z_hidden])
discrim_hidden = Dense(10, activation='relu')(discrim_hidden)
discrim_hidden = Dense(20, activation='relu')(discrim_hidden)
discrim_logit = Dense(1, activation=None,
                      name='logit')(discrim_hidden)
discrim_out = Activation('sigmoid')(discrim_logit)
discriminator = Model(inputs=[x_input, z_input], outputs=discrim_out)
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
# ratio estimator shares all weights with the discriminator but outputs
# the pre-sigmoid logit (the density-ratio trick)
ratio_estimator = Model(
    inputs=discriminator.inputs,
    outputs=discrim_logit)
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
np.ones((32, 5)) + np.ones((16, 5))
z_grid_ratio = ratio_estimator.predict([np.ones((16, 1)), np.ones((32, 2))])
z_grid_ratio.shape
Initial density ratio, prior to any training
fig, ax = plt.subplots(figsize=(5, 5))
# NOTE(review): this cell references w1, w2, w_grid_ratio, w_min, w_max,
# none of which are defined in this section (the grid here is z1/z2) —
# likely carried over from a later notebook revision; verify before running
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# NOTE(review): evaluating on raw prior samples with all-zero targets only
# sanity-checks the untrained discriminator
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
$z_{\phi}(x, \epsilon)$
Here we only consider
$z_{\phi}(\epsilon)$
$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
The variational parameters $\phi$ are the trainable weights of the approximate inference model
phi = inference.trainable_weights
phi
SVG(model_to_dot(inference, show_shapes=True)
.create(prog='dot', format='svg'))
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
np.exp(log_prior+np.sum(llhs, axis=2)),
cmap=plt.cm.magma)
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
metrics
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
metrics_plots = {k:ax1.plot([], label=k)[0]
for k in ['loss']} # discriminator.metrics_names}
ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):
    """Run one epoch of discriminator training and refresh the live plots.

    Intended as the frame callback for FuncAnimation. Reads/updates the
    notebook globals: prior, inference, discriminator, ratio_estimator,
    metrics_plots, ax1, ax2, w_grid, w1, w2, w_min, w_max, NOISE_DIM.
    Returns the metric line artists (for blitting).
    """
    # Single training epoch
    for step in tnrange(steps_per_epoch, unit='step', leave=False):
        w_sample_prior = prior.rvs(size=batch_size)
        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        # label 0 = prior samples, label 1 = approximate-posterior samples;
        # this binary classification problem defines the density ratio
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        metrics = discriminator.train_on_batch(inputs, targets)
    # Plot Metrics
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))
    for metric in metrics_plots:
        # append this epoch's value to the running metric curves
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(),
                                                  epoch_num))
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(),
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric,
                                                metrics_dict[metric]))
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot
    ax2.cla()
    # evaluate the ratio estimator over the whole 300x300 grid
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)
    # artists returned for FuncAnimation blitting
    return list(metrics_plots.values())
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate`
# function that encapsulates the logic of single
# training epoch. Has benefit of producing
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=PRETRAIN_EPOCHS,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
('accuracy: {binary_accuracy:.2f}\n'
'loss: {loss:.2f}').format(**metrics_dict),
transform=ax.transAxes, bbox=props)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def set_trainable(model, trainable):
    """Set the ``trainable`` flag on ``model`` and every nested layer.

    Nested ``Model`` instances are expanded so their own layers are
    flagged as well (depth-first traversal).
    """
    pending = [model]
    while pending:
        node = pending.pop()
        node.trainable = trainable
        if isinstance(node, Model):  # i.e. has layers
            pending.extend(node.layers)
y_pred = K.sigmoid(K.dot(
K.constant(w_grid),
K.transpose(K.constant(X))))
y_pred
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
llhs_keras = - K.binary_crossentropy(
y_pred,
y_true,
from_logits=False)
sess = K.get_session()
np.allclose(np.sum(llhs, axis=-1),
sess.run(K.sum(llhs_keras, axis=-1)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)),
cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def make_elbo(ratio_estimator):
    """Build a Keras loss computing the adversarially-estimated ELBO.

    The KL term of the ELBO is replaced by the ratio estimator's logit
    output (density-ratio trick). The estimator is frozen so that only
    the inference network receives gradient updates from this loss.
    Uses the notebook-global design matrix ``X``.
    """
    set_trainable(ratio_estimator, False)
    def elbo(y_true, w_sample):
        # discriminator logit used as an estimate of the KL integrand
        kl_estimate = ratio_estimator(w_sample)
        # Bernoulli log-likelihood of the labels under weight sample w
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true,
                                                 from_logits=True)
        # NOTE(review): the likelihood term is scaled by 2. here but not in
        # the later version of this function — confirm which is intended
        return K.mean(2.*log_likelihood-kl_estimate, axis=-1)
    return elbo
elbo = make_elbo(ratio_estimator)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))),
cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
inference.compile(loss=inference_loss,
optimizer=Adam(lr=LEARNING_RATE))
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0),
axis=0, rep=BATCH_SIZE)
y_true
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
global_epoch = 0
loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200,
                  steps_per_epoch=15):
    """One adversarial epoch: discriminator updates, then inference updates.

    FuncAnimation frame callback. Reads/updates the notebook globals:
    prior, inference, discriminator, ratio_estimator, loss_plot_inference,
    loss_plot_discrim, ax1, ax2, w_grid, w1, w2, w_min, w_max, y,
    BATCH_SIZE, NOISE_DIM, global_epoch. Returns the loss line artists.
    """
    global global_epoch, loss_plot_inference, loss_plot_discrim
    # Single training epoch
    ## Ratio estimator training
    set_trainable(discriminator, True)
    # many more discriminator steps than inference steps per epoch
    for _ in tnrange(3*50, unit='step', desc='discriminator',
                     leave=False):
        w_sample_prior = prior.rvs(size=BATCH_SIZE)
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        # label 0 = prior samples, label 1 = approximate-posterior samples
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
        metrics_discrim = discriminator.train_on_batch(inputs, targets)
    metrics_dict_discrim = dict(zip(discriminator.metrics_names,
                                    np.atleast_1d(metrics_discrim)))
    ## Inference model training
    set_trainable(ratio_estimator, False)
    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))
    for _ in tnrange(1, unit='step', desc='inference', leave=False):
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
    metrics_dict_inference = dict(zip(inference.metrics_names,
                                      np.atleast_1d(metrics_inference)))
    global_epoch += 1
    # Plot Loss
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(),
                                            metrics_dict_inference['loss']))
    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))
    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))
    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                .format(metrics_dict_discrim['loss']))
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot
    ax2.cla()
    # refresh the 300x300 density-ratio contour
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])
    # artists returned for FuncAnimation blitting
    return loss_plot_inference, loss_plot_discrim
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax1.set_xlabel('$w_1$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax2.contourf(w1, w2, np.sum(llhs, axis=2),
cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2,
np.exp(log_prior+np.sum(llhs, axis=2)),
cmap=plt.cm.magma)
ax1.scatter(*inference.predict(eps[::10]).T,
s=4.**2, alpha=.6, cmap='coolwarm_r')
ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
sns.kdeplot(*inference.predict(eps).T,
cmap='magma', ax=ax2)
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import keras.backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
np.set_printoptions(precision=2,
edgeitems=3,
linewidth=80,
suppress=True)
K.tf.__version__
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60
w_min, w_max = -5, 5
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
w_grid = np.dstack((w1, w2))
w_grid.shape
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
cov=PRIOR_VARIANCE)
log_prior = prior.logpdf(w_grid)
log_prior.shape
log_prior = -np.sum(w_grid**2, axis=2)/2/PRIOR_VARIANCE
log_prior.shape
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
x1 = np.array([ 1.5, 1.])
x2 = np.array([-1.5, 1.])
x3 = np.array([ .5, -1.])
X = np.vstack((x1, x2, x3))
X.shape
y1 = 1
y2 = 1
y3 = 0
y = np.stack((y1, y2, y3))
y.shape
def log_likelihood(w, x, y):
    """Bernoulli log-likelihood log p(y | x, w) for logistic regression.

    Equivalent to the negative binary cross-entropy: the label y in
    {0, 1} flips the sign of the logit w.T x via the factor (-1)**(1-y).
    """
    signs = (-1) ** (1 - y)
    logits = np.dot(w.T, x)
    return np.log(expit(logits * signs))
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()
for i, ax in enumerate(axes):
ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
ax.set_xlabel('$w_1$')
if not i:
ax.set_ylabel('$w_2$')
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.sum(llhs, axis=2),
cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
np.exp(log_prior+np.sum(llhs, axis=2)),
cmap='magma')
ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
$T_{\psi}(x, z)$
Here we consider
$T_{\psi}(w)$
$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
loss='binary_crossentropy',
metrics=['binary_accuracy'])
ratio_estimator = Model(
inputs=discriminator.inputs,
outputs=discriminator.get_layer(name='logit').output)
SVG(model_to_dot(discriminator, show_shapes=True)
.create(prog='dot', format='svg'))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
Initial density ratio, prior to any training
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
$z_{\phi}(x, \epsilon)$
Here we only consider
$z_{\phi}(\epsilon)$
$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
The variational parameters $\phi$ are the trainable weights of the approximate inference model
phi = inference.trainable_weights
phi
SVG(model_to_dot(inference, show_shapes=True)
.create(prog='dot', format='svg'))
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
np.exp(log_prior+np.sum(llhs, axis=2)),
cmap=plt.cm.magma)
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
metrics
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
metrics_plots = {k:ax1.plot([], label=k)[0]
for k in ['loss']} # discriminator.metrics_names}
ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):
    """Run one epoch of discriminator training and refresh the live plots.

    Intended as the frame callback for FuncAnimation. Reads/updates the
    notebook globals: prior, inference, discriminator, ratio_estimator,
    metrics_plots, ax1, ax2, w_grid, w1, w2, w_min, w_max, NOISE_DIM.
    Returns the metric line artists (for blitting).
    """
    # Single training epoch
    for step in tnrange(steps_per_epoch, unit='step', leave=False):
        w_sample_prior = prior.rvs(size=batch_size)
        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        # label 0 = prior samples, label 1 = approximate-posterior samples;
        # this binary classification problem defines the density ratio
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        metrics = discriminator.train_on_batch(inputs, targets)
    # Plot Metrics
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))
    for metric in metrics_plots:
        # append this epoch's value to the running metric curves
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(),
                                                  epoch_num))
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(),
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric,
                                                metrics_dict[metric]))
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot
    ax2.cla()
    # evaluate the ratio estimator over the whole 300x300 grid
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)
    # artists returned for FuncAnimation blitting
    return list(metrics_plots.values())
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate`
# function that encapsulates the logic of single
# training epoch. Has benefit of producing
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=PRETRAIN_EPOCHS,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
('accuracy: {binary_accuracy:.2f}\n'
'loss: {loss:.2f}').format(**metrics_dict),
transform=ax.transAxes, bbox=props)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
('accuracy: {binary_accuracy:.2f}\n'
'loss: {loss:.2f}').format(**metrics_dict),
transform=ax.transAxes, bbox=props)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def set_trainable(model, trainable):
    """Set the ``trainable`` flag on ``model`` and every nested layer.

    Nested ``Model`` instances are expanded so their own layers are
    flagged as well (depth-first traversal).
    """
    pending = [model]
    while pending:
        node = pending.pop()
        node.trainable = trainable
        if isinstance(node, Model):  # i.e. has layers
            pending.extend(node.layers)
y_pred = K.sigmoid(K.dot(
K.constant(w_grid),
K.transpose(K.constant(X))))
y_pred
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
llhs_keras = - K.binary_crossentropy(
y_pred,
y_true,
from_logits=False)
sess = K.get_session()
np.allclose(np.sum(llhs, axis=-1),
sess.run(K.sum(llhs_keras, axis=-1)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)),
cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def make_elbo(ratio_estimator):
    """Build a Keras loss computing the adversarially-estimated ELBO.

    The KL term of the ELBO is replaced by the ratio estimator's logit
    output (density-ratio trick). The estimator is frozen so that only
    the inference network receives gradient updates from this loss.
    Uses the notebook-global design matrix ``X``.
    """
    set_trainable(ratio_estimator, False)
    def elbo(y_true, w_sample):
        # discriminator logit used as an estimate of the KL integrand
        kl_estimate = ratio_estimator(w_sample)
        # Bernoulli log-likelihood of the labels under weight sample w
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true,
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)
    return elbo
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(inference_loss(y_true, K.constant(w_grid))),
cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
inference.compile(loss=inference_loss,
optimizer=Adam(lr=LEARNING_RATE))
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0),
axis=0, rep=BATCH_SIZE)
y_true
sess.run(K.mean(inference_loss(y_true, inference(K.constant(eps))), axis=-1))
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
global_epoch = 0
loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200,
                  steps_per_epoch=15):
    """One adversarial epoch: discriminator updates, then inference updates.

    FuncAnimation frame callback. Reads/updates the notebook globals:
    prior, inference, discriminator, ratio_estimator, loss_plot_inference,
    loss_plot_discrim, ax1, ax2, w_grid, w1, w2, w_min, w_max, y,
    BATCH_SIZE, NOISE_DIM, global_epoch. Returns the loss line artists.
    """
    global global_epoch, loss_plot_inference, loss_plot_discrim
    # Single training epoch
    ## Ratio estimator training
    set_trainable(discriminator, True)
    # many more discriminator steps than inference steps per epoch
    for _ in tnrange(3*50, unit='step', desc='discriminator',
                     leave=False):
        w_sample_prior = prior.rvs(size=BATCH_SIZE)
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        # label 0 = prior samples, label 1 = approximate-posterior samples
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
        metrics_discrim = discriminator.train_on_batch(inputs, targets)
    metrics_dict_discrim = dict(zip(discriminator.metrics_names,
                                    np.atleast_1d(metrics_discrim)))
    ## Inference model training
    set_trainable(ratio_estimator, False)
    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))
    for _ in tnrange(1, unit='step', desc='inference', leave=False):
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
    metrics_dict_inference = dict(zip(inference.metrics_names,
                                      np.atleast_1d(metrics_inference)))
    global_epoch += 1
    # Plot Loss
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(),
                                            metrics_dict_inference['loss']))
    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))
    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))
    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                .format(metrics_dict_discrim['loss']))
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot
    ax2.cla()
    # refresh the 300x300 density-ratio contour
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])
    # artists returned for FuncAnimation blitting
    return loss_plot_inference, loss_plot_discrim
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
unit='epoch', leave=True) as prog_bar:
anim = FuncAnimation(fig,
train_animate,
fargs=(prog_bar,),
frames=50,
interval=200, # 5 fps
blit=True)
anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax1.set_xlabel('$w_1$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax2.contourf(w1, w2, np.sum(llhs, axis=2),
cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2,
np.exp(log_prior+np.sum(llhs, axis=2)),
cmap=plt.cm.magma)
ax1.scatter(*inference.predict(eps[::10]).T,
s=4.**2, alpha=.6, cmap='coolwarm_r')
ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
sns.kdeplot(*inference.predict(eps).T,
cmap='magma', ax=ax2)
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
output = expit(np.random.randn(256))
target = np.hstack((np.zeros(128), np.ones(128)))
2*K.mean(K.binary_crossentropy(output=K.constant(output),
target=K.constant(target))).eval(session=sess)
np.mean(-np.log(output[128:])-np.log(1-output[:128]))
(-np.log(output[:128])-np.log(1-output[128:])).shape
p1[:128][0]
p1[128:][0]
ratio_estimator.get_weights()[0]
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import keras.backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
K.tf.__version__
# Problem constants: 2-d weight space, 3-d noise input to the inference net.
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60
# 300x300 evaluation grid over the weight space [-5, 5]^2.
w_min, w_max = -5, 5
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
w_grid = np.dstack((w1, w2))
w_grid.shape
# Isotropic Gaussian prior over the weights, evaluated on the grid.
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
                            cov=PRIOR_VARIANCE)
log_prior = prior.logpdf(w_grid)
log_prior.shape
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Tiny 3-point binary-classification dataset (labels in {0, 1} here).
x1 = np.array([ 1.5, 1.])
x2 = np.array([-1.5, 1.])
x3 = np.array([ .5, -1.])
X = np.vstack((x1, x2, x3))
X.shape
y1 = 1
y2 = 1
y3 = 0
y = np.stack((y1, y2, y3))
y.shape
def log_likelihood(w, x, y):
    """Log-likelihood of 0/1 labels ``y`` under a logistic model with weights ``w``.

    Computes ``log sigmoid(s)`` with score ``s = w.T @ x`` for ``y == 1`` and
    ``log sigmoid(-s)`` for ``y == 0``; the sign flip is encoded by
    ``(-1)**(1 - y)``. Equivalent to the negative binary cross-entropy.

    Fix: the original ``np.log(expit(z))`` underflows to ``log(0) = -inf``
    (with a runtime warning) for large-magnitude negative scores; the identity
    ``log sigmoid(z) = -logaddexp(0, -z)`` computes the same value stably.
    """
    scores = np.dot(w.T, x) * (-1)**(1 - y)
    # log(expit(z)) == -log(1 + exp(-z)), evaluated without underflow
    return -np.logaddexp(0, -scores)
# Evaluate the per-datapoint log-likelihood surface on the weight grid and
# visualise: one panel per datapoint, the summed log-likelihood, and the
# unnormalised posterior (prior * likelihood).
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()
for i, ax in enumerate(axes):
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)
    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')
    # only label the shared y-axis on the leftmost panel
    if not i:
        ax.set_ylabel('$w_2$')
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.sum(llhs, axis=2),
            cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_prior+np.sum(llhs, axis=2)),
            cmap='magma')
ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
In general, the AVB discriminator is a function $T_{\psi}(x, z)$ of both the data and the latent code. Here we consider the simpler $T_{\psi}(w)$, i.e. a mapping $T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$.
# Discriminator T_psi: 2-d weight sample -> 10 -> 20 -> 1 logit -> sigmoid.
# `ratio_estimator` shares its weights but outputs the pre-sigmoid logit,
# which serves as the log density-ratio estimate.
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
ratio_estimator = Model(
    inputs=discriminator.inputs,
    outputs=discriminator.get_layer(name='logit').output)
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
The initial density-ratio estimate produced by the untrained discriminator, before any training has taken place.
# Visualise the (untrained) log density-ratio surface and sanity-check the
# discriminator on a handful of prior samples labelled as class 0.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
In general, the inference network is a function $z_{\phi}(x, \epsilon)$ of the data and a noise variable. Here we only consider $z_{\phi}(\epsilon)$, i.e. a mapping $z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$.
# Implicit inference model z_phi: 3-d noise -> 10 -> 20 -> 2-d weight sample.
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
The variational parameters $\phi$ are the trainable weights of the approximate inference model.
# Inspect the variational parameters and draw a first batch of samples from
# the prior (class 0) and the untrained inference model (class 1).
phi = inference.trainable_weights
phi
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_prior+np.sum(llhs, axis=2)),
            cmap=plt.cm.magma)
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
metrics
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
# Set up the figure updated in-place by `train_animate`: metric curves on
# ax1, the current ratio estimate + sample scatter on ax2.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
metrics_plots = {k:ax1.plot([], label=k)[0]
                 for k in ['loss']} # discriminator.metrics_names}
ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):
    """One animation frame = one discriminator training epoch.

    Trains the discriminator for ``steps_per_epoch`` batches against fresh
    prior/posterior samples, then redraws the metric curves (``ax1``) and
    the ratio-estimate contour with a sample scatter (``ax2``).
    Relies on module-level state: ``prior``, ``inference``, ``discriminator``,
    ``ratio_estimator``, ``metrics_plots``, ``ax1``, ``ax2`` and grid globals.
    """
    # Single training epoch
    for step in tnrange(steps_per_epoch, unit='step', leave=False):
        w_sample_prior = prior.rvs(size=batch_size)
        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        metrics = discriminator.train_on_batch(inputs, targets)
    # Plot Metrics (from the last batch of the epoch)
    metrics_dict = dict(zip(discriminator.metrics_names, metrics))
    for metric in metrics_plots:
        metrics_plots[metric].set_xdata(np.append(metrics_plots[metric].get_xdata(),
                                                  epoch_num))
        metrics_plots[metric].set_ydata(np.append(metrics_plots[metric].get_ydata(),
                                                  metrics_dict[metric]))
        metrics_plots[metric].set_label('{} ({:.2f})' \
                                        .format(metric,
                                                metrics_dict[metric]))
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot (cleared and redrawn each frame)
    ax2.cla()
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(**metrics_dict)
    # Artists to re-draw when blitting
    return list(metrics_plots.values())
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate`
# function that encapsulates the logic of single
# training epoch. Has benefit of producing
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
# Evaluate the pretrained discriminator on the most recent samples and
# annotate the ratio-estimate plot with its loss/accuracy.
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
metrics_dict = dict(zip(discriminator.metrics_names, metrics))
props = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
        ('accuracy: {binary_accuracy:.2f}\n'
         'loss: {loss:.2f}').format(**metrics_dict),
        transform=ax.transAxes, bbox=props)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def set_trainable(model, trainable):
    """Recursively set the ``trainable`` flag on a model/layer and everything it contains.

    The flag is set on the node itself first, then on each nested layer
    (an in-order / top-down traversal).
    """
    model.trainable = trainable
    if not isinstance(model, Model):  # plain layers have no sub-layers
        return
    for sub_layer in model.layers:
        set_trainable(sub_layer, trainable)
# Reproduce the numpy log-likelihood surface with Keras backend ops and
# verify the two computations agree.
y_pred = K.sigmoid(K.dot(
    K.constant(w_grid),
    K.transpose(K.constant(X))))
y_pred
# Broadcast the labels over the 300x300 grid.
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
# Negative binary cross-entropy == per-point log-likelihood.
llhs_keras = - K.binary_crossentropy(
    y_pred,
    y_true,
    from_logits=False)
sess = K.get_session()
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)),
            cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
def make_elbo(ratio_estimator):
    """Build the adversarial ELBO loss given a (frozen) ratio estimator.

    The returned ``elbo(y_true, w_sample)`` treats the ratio estimator's
    logit output as a single-sample estimate of the KL term, and computes
    the expected log-likelihood as the negative binary cross-entropy of the
    logits ``w_sample @ X.T`` against ``y_true``. Freezing the estimator
    here means only the inference network is updated by this loss.
    """
    set_trainable(ratio_estimator, False)
    def elbo(y_true, w_sample):
        # log q(w)/p(w) estimate (discriminator logit)
        kl_estimate = ratio_estimator(w_sample)
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        log_likelihood = - K.binary_crossentropy(y_pred, y_true,
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)
    return elbo
# Visualise the ELBO surface, compile the inference model against the
# negative ELBO, and set up the figure for the alternating training loop.
elbo = make_elbo(ratio_estimator)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))),
            cmap=plt.cm.magma)
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Maximising the ELBO == minimising its negation.
inference_loss = lambda y_true, w_sample: -make_elbo(ratio_estimator)(y_true, w_sample)
inference.compile(loss=inference_loss,
                  optimizer=Adam(lr=LEARNING_RATE))
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0),
                           axis=0, rep=BATCH_SIZE)
y_true
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
global_epoch = 0
loss_plot_inference, = ax1.plot([], label='inference')
loss_plot_discrim, = ax1.plot([], label='discriminator')
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200,
                  steps_per_epoch=15):
    """One animation frame = one alternating (discriminator/inference) epoch.

    Runs 150 discriminator updates then one inference-network update, and
    redraws the loss curves (``ax1``) and the ratio-estimate contour with a
    sample scatter (``ax2``). Uses module-level models, plot artists and
    ``global_epoch``.
    """
    global global_epoch, loss_plot_inference, loss_plot_discrim
    # Single training epoch
    ## Ratio estimator training
    set_trainable(discriminator, True)
    for _ in tnrange(3*50, unit='step', desc='discriminator',
                     leave=False):
        w_sample_prior = prior.rvs(size=BATCH_SIZE)
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
        metrics_discrim = discriminator.train_on_batch(inputs, targets)
    metrics_dict_discrim = dict(zip(discriminator.metrics_names,
                                    np.atleast_1d(metrics_discrim)))
    ## Inference model training (ratio estimator frozen)
    set_trainable(ratio_estimator, False)
    y_tiled = np.tile(y, reps=(BATCH_SIZE, 1))
    for _ in tnrange(1, unit='step', desc='inference', leave=False):
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, y_tiled)
    metrics_dict_inference = dict(zip(inference.metrics_names,
                                      np.atleast_1d(metrics_inference)))
    global_epoch += 1
    # Plot Loss
    loss_plot_inference.set_xdata(np.append(loss_plot_inference.get_xdata(),
                                            global_epoch))
    loss_plot_inference.set_ydata(np.append(loss_plot_inference.get_ydata(),
                                            metrics_dict_inference['loss']))
    loss_plot_inference.set_label('inference ({:.2f})' \
                                  .format(metrics_dict_inference['loss']))
    loss_plot_discrim.set_xdata(np.append(loss_plot_discrim.get_xdata(),
                                          global_epoch))
    loss_plot_discrim.set_ydata(np.append(loss_plot_discrim.get_ydata(),
                                          metrics_dict_discrim['loss']))
    loss_plot_discrim.set_label('discriminator ({:.2f})' \
                                .format(metrics_dict_discrim['loss']))
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot (cleared and redrawn each frame)
    ax2.cla()
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(loss_inference=metrics_dict_inference['loss'],
                         loss_discriminator=metrics_dict_discrim['loss'])
    return loss_plot_inference, loss_plot_discrim
# The alternating training cell is repeated four times; each run continues
# training for another 50 epochs from the current parameter state.
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
# Final diagnostics: learned ratio estimate (left) vs true summed
# log-likelihood (right), with prior/posterior samples overlaid.
# NOTE(review): 128 prior vs 256 posterior samples — class-imbalanced;
# confirm intentional (same pattern as the earlier diagnostics cell).
w_sample_prior = prior.rvs(size=128)
eps = np.random.randn(256, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(128), np.ones(256)))
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax1.set_xlabel('$w_1$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax2.contourf(w1, w2, np.sum(llhs, axis=2),
             cmap=plt.cm.magma)
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
# Final visual check: true (unnormalised) posterior with posterior samples
# overlaid (left), and a KDE of the approximate posterior (right).
eps = np.random.randn(5000, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2,
             np.exp(log_prior+np.sum(llhs, axis=2)),
             cmap=plt.cm.magma)
ax1.scatter(*inference.predict(eps[::10]).T,
            s=4.**2, alpha=.6, cmap='coolwarm_r')
ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
sns.kdeplot(*inference.predict(eps).T,
            cmap='magma', ax=ax2)
# Fix: clamp ax2 to the same plotting window as ax1 — the earlier, parallel
# diagnostics cell does this, but it was missing here, letting the KDE axes
# autoscale and making the two panels incomparable.
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import theano
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
from scipy.special import expit
from scipy.stats import logistic
from theano import tensor as T
from theano.tensor.shared_randomstreams import RandomStreams
from theano.printing import debugprint
from lasagne.updates import adam
from lasagne.utils import floatX
from lasagne.nonlinearities import sigmoid
from lasagne.layers import get_output, get_all_params
from lasagne.layers import (InputLayer,
DenseLayer,
NonlinearityLayer)
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
# Theano/Lasagne re-implementation of the same toy problem.
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60
w_min, w_max = -5, 5
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
w_grid = np.dstack((w1, w2))
w_grid.shape
# Unnormalised Gaussian log-prior on the grid (constant term dropped).
log_prior = -.5*np.sum(w_grid**2, axis=2)/PRIOR_VARIANCE
log_prior.shape
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Same 3-point dataset, but labels are +/-1 in this section (not 0/1).
x1 = np.array([ 1.5, 1.])
x2 = np.array([-1.5, 1.])
x3 = np.array([ .5, -1.])
X = np.vstack((x1, x2, x3))
X.shape
y1 = 1
y2 = 1
y3 = -1
y = np.stack((y1, y2, y3))
y.shape
def log_likelihood(w, x, y):
    """Log-likelihood ``log sigmoid(y * (w.T @ x))`` for +/-1 labels ``y``.

    Equivalent to the negative binary cross-entropy; ``logistic.logcdf`` is
    the numerically stable log-sigmoid.
    """
    scores = np.dot(w.T, x)
    return logistic.logcdf(scores * y)
# Per-datapoint log-likelihood surfaces, their sum, and the unnormalised
# log posterior, all on the 300x300 weight grid.
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
fig, axes = plt.subplots(ncols=3, figsize=(6, 2))
fig.tight_layout()
for i, ax in enumerate(axes):
    ax.contourf(w1, w2, llhs[::,::,i], cmap=plt.cm.magma)
    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')
    # only the leftmost panel gets the shared y-label
    if not i:
        ax.set_ylabel('$w_2$')
plt.show()
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.sum(llhs, axis=2), cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# unnormalised log posterior
# only for illustration purposes
log_post = log_prior + np.sum(llhs, axis=2)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.exp(log_post), cmap='magma')
ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
This part implements the actual adversarial training. Here we define the generator and discriminator networks in Lasagne, and implement the two loss functions in Theano.
#defines a 'generator' network
def build_G(input_var=None, num_z = 3):
    """Build the 'generator' (inference) network: noise (num_z) -> 10 -> 20 -> 2-d weight sample."""
    net = InputLayer(input_var=input_var, shape=(None, num_z))
    net = DenseLayer(incoming=net, num_units=10)
    net = DenseLayer(incoming=net, num_units=20)
    # linear output layer: unconstrained 2-d samples
    return DenseLayer(incoming=net, num_units=2, nonlinearity=None)
#defines the 'discriminator network'
def build_D(input_var=None):
    """Build the discriminator: 2-d input -> 10 -> 20 -> 1 logit.

    Returns a dict with the raw-logit head ('unnorm') and the
    sigmoid-normalised head ('norm') sharing all weights.
    """
    logit = InputLayer(input_var=input_var, shape=(None, 2))
    logit = DenseLayer(incoming=logit, num_units=10)
    logit = DenseLayer(incoming=logit, num_units=20)
    logit = DenseLayer(incoming=logit, num_units=1, nonlinearity=None)
    prob = NonlinearityLayer(incoming=logit, nonlinearity=sigmoid)
    return {'unnorm': logit, 'norm': prob}
#variables for input (design matrix), output labels, GAN noise variable, weights
x_var = T.matrix('design matrix')
y_var = T.vector('labels')
z_var = T.matrix('GAN noise')
w_var = T.matrix('weights')
#theano variables for things like batchsize, learning rate, etc.
batchsize_var = T.scalar('batchsize', dtype='int32')
prior_variance_var = T.scalar('prior variance')
learningrate_var = T.scalar('learning rate')
#random numbers for sampling from the prior or from the GAN
srng = RandomStreams(seed=13574437)
z_rnd = srng.normal((batchsize_var,3))
prior_rnd = srng.normal((batchsize_var,2))
#instantiating the G and D networks
generator = build_G(z_var)
discriminator = build_D()
#these expressions are random samples from the generator and the prior, respectively
# NOTE(review): 'grenerator' is a typo for 'generator'; the misspelled name is
# used consistently below, so it is kept to avoid breaking references.
samples_from_grenerator = get_output(generator, z_rnd)
samples_from_prior = prior_rnd*T.sqrt(prior_variance_var)
#discriminator output for synthetic samples, both normalised and unnormalised (after/before sigmoid)
D_of_G = get_output(discriminator['norm'], inputs=samples_from_grenerator)
s_of_G = get_output(discriminator['unnorm'], inputs=samples_from_grenerator)
#discriminator output for real samples from the prior
D_of_prior = get_output(discriminator['norm'], inputs=samples_from_prior)
#loss of discriminator - simple binary cross-entropy loss
loss_D = -T.log(D_of_G).mean() - T.log(1-D_of_prior).mean()
#log likelihood for each synthetic w sampled from the generator
# (broadcast labels over data dims and samples; sum over features, then data)
log_likelihood = T.log(
    T.nnet.sigmoid(
        (y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * samples_from_grenerator.dimshuffle('x', 1, 0))).sum(1)
    )
).sum(0).mean()
#loss for G is the sum of unnormalised discriminator output and the negative log likelihood
loss_G = s_of_G.mean() - log_likelihood
#compiling theano functions:
evaluate_generator = theano.function(
    [z_var],
    get_output(generator),
    allow_input_downcast=True
)
sample_generator = theano.function(
    [batchsize_var],
    samples_from_grenerator,
    allow_input_downcast=True,
)
sample_prior = theano.function(
    [prior_variance_var, batchsize_var],
    samples_from_prior,
    allow_input_downcast=True
)
# Discriminator training function: Adam updates on the cross-entropy loss.
params_D = get_all_params(discriminator['norm'], trainable=True)
updates_D = adam(
    loss_D,
    params_D,
    learning_rate = learningrate_var
)
train_D = theano.function(
    [learningrate_var, batchsize_var, prior_variance_var],
    loss_D,
    updates = updates_D,
    allow_input_downcast = True
)
# Generator training function: Adam updates on loss_G (KL estimate - llh).
params_G = get_all_params(generator, trainable=True)
updates_G = adam(
    loss_G,
    params_G,
    learning_rate = learningrate_var
)
train_G = theano.function(
    [x_var, y_var, learningrate_var, batchsize_var],
    loss_G,
    updates = updates_G,
    allow_input_downcast = True
)
# Returns (logit, sigmoid) outputs for arbitrary weight samples.
evaluate_discriminator = theano.function(
    [w_var],
    get_output([discriminator['unnorm'],discriminator['norm']],w_var),
    allow_input_downcast = True
)
#this is to evaluate the log-likelihood of an arbitrary set of w
llh_for_w = T.nnet.sigmoid((y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * w_var.dimshuffle('x', 1, 0))).sum(1))
evaluate_loglikelihood = theano.function(
    [x_var, y_var, w_var],
    llh_for_w,
    allow_input_downcast = True
)
# Visualise the untrained discriminator's logit surface.
fig, ax = plt.subplots(figsize=(5, 5))
w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300,300)
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# First prior/generator sample batch overlaid on the true posterior.
w_sample_prior = sample_prior(PRIOR_VARIANCE, BATCH_SIZE)
w_sample_posterior = sample_generator(BATCH_SIZE)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.exp(log_post), cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Figure updated in-place by the animated training loop below.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
loss_plot, = ax1.plot([], label='loss')
ax1.set_xlabel('epoch')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200, steps_per_epoch=15):
    """One animation frame = one discriminator training epoch (Theano version).

    Runs ``steps_per_epoch`` updates via the compiled ``train_D`` function,
    then redraws the loss curve (``ax1``) and the discriminator logit
    contour with a fresh prior/generator sample scatter (``ax2``).
    Relies on module-level compiled functions (``train_D``, ``sample_prior``,
    ``sample_generator``, ``evaluate_discriminator``) and figure artists.
    """
    # Single training epoch
    for step in tnrange(steps_per_epoch, unit='step', leave=False):
        # Fix: `np.asscalar` was deprecated (NumPy 1.16) and removed (1.23);
        # float(...) extracts the scalar from the 0-d array identically.
        loss = float(train_D(LEARNING_RATE,
                             batch_size,
                             PRIOR_VARIANCE))
    w_sample_prior = sample_prior(PRIOR_VARIANCE, batch_size)
    w_sample_posterior = sample_generator(batch_size)
    inputs = np.vstack((w_sample_prior, w_sample_posterior))
    targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
    w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300,300)
    # Plot Loss (last batch's loss of the epoch)
    loss_plot.set_xdata(np.append(loss_plot.get_xdata(), epoch_num))
    loss_plot.set_ydata(np.append(loss_plot.get_ydata(), loss))
    loss_plot.set_label('loss ({:.2f})'.format(loss))
    ax1.set_xlabel('epoch {:2d}'.format(epoch_num))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot (cleared and redrawn each frame)
    ax2.cla()
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    prog_bar.set_postfix(loss=loss)
    return loss_plot,
# main training loop is managed by higher-order
# FuncAnimation which makes calls to an `animate`
# function that encapsulates the logic of single
# training epoch. Has benefit of producing
# animation but can incur significant overhead
with tqdm_notebook(total=PRETRAIN_EPOCHS,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=PRETRAIN_EPOCHS,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
# Pretrained discriminator logit surface with the last sample batch.
fig, ax = plt.subplots(figsize=(5, 5))
w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300,300)
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Cross-check the Theano likelihood against the numpy implementation.
llh_theano = evaluate_loglikelihood(X, y, w_grid.reshape(300*300, 2))
llh_theano.shape
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2, np.sum(llhs, axis=2),
             cmap='magma')
ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax1.set_title('numpy loglikelihood')
ax2.contourf(w1, w2, np.sum(np.log(llh_theano), 0).reshape(300,300),
             cmap='magma')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
ax2.set_title('theano loglikelihood')
plt.show()
np.allclose(np.sum(llhs, axis=2),
            np.sum(np.log(llh_theano), 0).reshape(300,300))
# Figure for the alternating (G/D) training loop: loss curves on ax1,
# ratio-estimate contour plus sample scatter on ax2.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))
global_epoch = 0
plots_dict = {k:ax1.plot([], label=k)[0] for k in ('inference',
                                                   'discriminator',
                                                   'neg_log_likelihood',
                                                   'kl')}
ax1.set_xlabel('epoch')
ax1.set_ylabel('loss')
ax1.legend(loc='upper left')
ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
plt.show()
def train_animate(epoch_num, prog_bar, batch_size=200,
                  steps_per_epoch=15):
    """One animation frame = one alternating training epoch (Theano version).

    Runs 150 discriminator updates followed by a single generator
    ("inference") update, then redraws the tracked loss curves (``ax1``)
    and the discriminator contour with a fresh sample scatter (``ax2``).
    Uses module-level compiled Theano functions and figure artists.
    """
    global global_epoch, plots_dict
    plot_values = defaultdict(int)
    # Single training epoch
    ## Ratio estimator training
    for _ in tnrange(150, unit='step', desc='discriminator',
                     leave=False):
        # Fix: `np.asscalar` was deprecated and removed from NumPy;
        # float(...) extracts the scalar from the 0-d array identically.
        plot_values['discriminator'] = float(train_D(LEARNING_RATE,
                                                     batch_size,
                                                     PRIOR_VARIANCE))
    ## Inference model training (return value intentionally unused)
    for _ in tnrange(1, unit='step', desc='inference', leave=False):
        train_G(X, y, LEARNING_RATE, batch_size)
    global_epoch += 1
    w_sample_prior = sample_prior(PRIOR_VARIANCE, batch_size)
    w_sample_posterior = sample_generator(batch_size)
    inputs = np.vstack((w_sample_prior, w_sample_posterior))
    targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
    # KL term is estimated by the mean unnormalised (logit) output
    plot_values['kl'] = np.mean(evaluate_discriminator(w_sample_posterior)[0])
    plot_values['neg_log_likelihood'] = -np.mean(np.sum(np.log(evaluate_loglikelihood(X, y, w_sample_posterior)), axis=0))
    plot_values['inference'] = plot_values['kl'] + plot_values['neg_log_likelihood']
    w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300,300)
    # Plot Loss
    for k in plots_dict:
        plots_dict[k].set_xdata(np.append(plots_dict[k].get_xdata(), global_epoch))
        plots_dict[k].set_ydata(np.append(plots_dict[k].get_ydata(), plot_values[k]))
        plots_dict[k].set_label('{} ({:.2f})'.format(k, plot_values[k]))
    ax1.set_xlabel('epoch {:2d}'.format(global_epoch))
    ax1.legend(loc='upper left')
    ax1.relim()
    ax1.autoscale_view()
    # Contour Plot (cleared and redrawn each frame)
    ax2.cla()
    ax2.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax2.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    ax2.set_xlabel('$w_1$')
    ax2.set_ylabel('$w_2$')
    ax2.set_xlim(w_min, w_max)
    ax2.set_ylim(w_min, w_max)
    # Progress Bar Updates
    prog_bar.update()
    # Fix: the original commented-out call referenced undefined `loss_g` /
    # `loss_d`; report the losses actually tracked in `plot_values`.
    prog_bar.set_postfix(loss_G=plot_values['inference'],
                         loss_D=plot_values['discriminator'])
    return list(plots_dict.values())
# The alternating training cell is repeated four times; each run continues
# training for another 50 epochs from the current parameter state.
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
with tqdm_notebook(total=50,
                   unit='epoch', leave=True) as prog_bar:
    anim = FuncAnimation(fig,
                         train_animate,
                         fargs=(prog_bar,),
                         frames=50,
                         interval=200, # 5 fps
                         blit=True)
    anim_html5_video = anim.to_html5_video()
HTML(anim_html5_video)
# Final diagnostics: estimated log density ratio vs true log-likelihood.
w_grid_ratio, _ = evaluate_discriminator(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300,300)
w_sample_prior = sample_prior(PRIOR_VARIANCE, 100)
w_sample_posterior = sample_generator(100)
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(100), np.ones(100)))
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax1.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax1.set_xlabel('$w_1$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax1.set_title('estimated log density ratio $\Phi^{-1}(D)$')
ax2.contourf(w1, w2, np.sum(llhs, axis=2),
             cmap=plt.cm.magma)
ax2.set_xlabel('$w_1$')
ax2.set_ylabel('$w_2$')
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
ax2.set_title('log likelihood')
plt.show()
# Generator samples overlaid on the true (unnormalised) posterior.
fig, ax = plt.subplots( figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_post),
            cmap=plt.cm.magma)
ax.scatter(*sample_generator(1000).T,
           s=4.**2, alpha=.6, cmap='coolwarm_r')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Side-by-side: true posterior vs KDE of the approximate posterior.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))
ax1.contourf(w1, w2,
             np.exp(log_prior+np.sum(llhs, axis=2)),
             cmap=plt.cm.magma)
ax1.scatter(*sample_generator(100).T,
            s=4.**2, alpha=.6, cmap='coolwarm_r')
ax1.set_xlabel('$w_1$')
ax1.set_ylabel('$w_2$')
ax1.set_xlim(w_min, w_max)
ax1.set_ylim(w_min, w_max)
ax1.set_title('true posterior')
sns.kdeplot(*sample_generator(5000).T,
            shade=True, cmap='magma', ax=ax2)
ax2.set_xlim(w_min, w_max)
ax2.set_ylim(w_min, w_max)
ax2.set_title('kde of approximate posterior')
plt.show()
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import numpy as np
import keras.backend as K
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import SVG
from tqdm import tnrange
# display animation inline
# (This section repeats the earlier Keras setup, without PRETRAIN_EPOCHS.)
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
K.tf.__version__
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
w_min, w_max = -5, 5
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
w_grid = np.dstack((w1, w2))
w_grid.shape
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
                            cov=PRIOR_VARIANCE)
log_prior = prior.logpdf(w_grid)
log_prior.shape
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')
ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)
plt.show()
# Same 3-point dataset; labels are 0/1 in this section.
x1 = np.array([ 1.5, 1.])
x2 = np.array([-1.5, 1.])
x3 = np.array([ .5, -1.])
X = np.vstack((x1, x2, x3))
X.shape
y1 = 1
y2 = 1
y3 = 0
y = np.stack((y1, y2, y3))
y.shape
def log_likelihood(w, x, y):
    """Log-likelihood log p(y | x, w) of the Bernoulli-logistic model.

    Equivalent to the negative binary cross-entropy.  The sign trick
    ``(-1)**(1-y)`` flips the logit ``z = w.T @ x`` for label 0, so the
    result is ``log sigmoid(+z)`` for y=1 and ``log sigmoid(-z)`` for y=0.

    Computed as ``-log(1 + exp(-z))`` via ``np.logaddexp`` rather than the
    original ``np.log(expit(z))``, which underflows to ``-inf`` (and loses
    precision) for large-magnitude negative logits.
    """
    z = np.dot(w.T, x)*(-1)**(1-y)
    return -np.logaddexp(0, -z)
# Evaluate the per-datum log-likelihood over the whole weight grid:
# broadcasting yields one (300, 300) surface per observation -> (300, 300, 3).
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
# One likelihood surface p(y_i | x_i, w) per observation.
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()
for i, ax in enumerate(axes):
    ax.contourf(w1, w2, llhs[..., i], cmap=plt.cm.magma)
    ax.set(xlim=(w_min, w_max), ylim=(w_min, w_max), xlabel='$w_1$')
    ax.set_title('$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    # only label the y-axis on the leftmost panel
    if i == 0:
        ax.set_ylabel('$w_2$')
plt.show()
# Joint likelihood of all three observations (sum of log-likelihoods).
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, np.sum(llhs, axis=2),
            cmap=plt.cm.magma)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
# Unnormalised posterior exp(log p(w) + sum_n log p(y_n|x_n, w)), with the
# data points overlaid, coloured by class label.
fig, ax = plt.subplots(figsize=(5, 5))
posterior_unnorm = np.exp(log_prior+np.sum(llhs, axis=2))
ax.contourf(w1, w2, posterior_unnorm, cmap='magma')
ax.scatter(*X.T, c=y, cmap='coolwarm', marker=',')
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
In general, the discriminator is a function $T_{\psi}(x, z)$ of both the data and the latent variables.
Here we consider a discriminator $T_{\psi}(w)$ of the weights alone, i.e. a mapping
$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$.
# Discriminator T_psi: R^2 -> [0, 1], with the pre-sigmoid layer named
# 'logit' so the raw log density-ratio can be read off separately.
discriminator = Sequential([
    Dense(10, input_dim=LATENT_DIM, activation='relu'),
    Dense(20, activation='relu'),
    Dense(1, activation=None, name='logit'),
    Activation('sigmoid'),
], name='discriminator')
discriminator.compile(optimizer=Adam(lr=LEARNING_RATE),
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
# Shares every weight with the discriminator but outputs the raw logit
# (the log density-ratio estimate) instead of the sigmoid probability.
ratio_estimator = Model(
    inputs=discriminator.inputs,
    outputs=discriminator.get_layer(name='logit').output)
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
# Score every point of the 300x300 weight grid with the (logit-output)
# ratio estimator, then restore the grid shape for contour plotting.
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
The initial density-ratio estimate, before any training:
# The untrained estimator's density-ratio surface.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
# Sanity check: score a handful of prior samples (all labelled as class 0).
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
In general, the sampler is a function $z_{\phi}(x, \epsilon)$ of both the data and the noise.
Here we only consider a sampler $z_{\phi}(\epsilon)$ of the noise alone, i.e. a mapping
$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$.
# Inference network z_phi: noise in R^3 -> approximate-posterior sample in R^2.
inference = Sequential([
    Dense(10, input_dim=NOISE_DIM, activation='relu'),
    Dense(20, activation='relu'),
    Dense(LATENT_DIM, activation=None),
])
inference.summary()
The variational parameters $\phi$ are the trainable weights of the approximate inference model
# phi: the variational parameters, i.e. the trainable weights of the
# inference network.
phi = inference.trainable_weights
phi
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
# A minibatch from each of the two distributions the discriminator must
# tell apart: the prior p(w) ...
w_sample_prior = prior.rvs(size=BATCH_SIZE)
w_sample_prior.shape
# ... and the approximate posterior q(w), sampled by pushing standard
# Gaussian noise through the inference network.
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
w_sample_posterior = inference.predict(eps)
w_sample_posterior.shape
# Stack into a single classification batch: prior samples are class 0,
# approximate-posterior samples are class 1.
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
# Both sample sets overlaid on the unnormalised true posterior.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2,
            np.exp(log_prior+np.sum(llhs, axis=2)),
            cmap=plt.cm.magma)
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
# Evaluate the (still untrained) discriminator on the batch and plot its
# current density-ratio estimate with loss/accuracy shown in a text box.
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
report = dict(zip(discriminator.metrics_names, metrics))
box_style = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
        ('accuracy: {binary_accuracy:.2f}\n'
         'loss: {loss:.2f}').format(**report),
        transform=ax.transAxes, bbox=box_style)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
def train_animate(epoch_num, batch_size=200, steps_per_epoch=15):
    """Animation callback: run one 'epoch' of discriminator training, then
    redraw the density-ratio contour plot on the shared axes.

    Closes over module-level state: ``fig``/``ax`` from the preceding cell,
    plus ``prior``, ``inference``, ``discriminator``, ``ratio_estimator``
    and the plotting grid (``w_grid``, ``w1``, ``w2``).

    Parameters
    ----------
    epoch_num : int
        Frame index supplied by ``FuncAnimation``; shown in the text box.
    batch_size : int
        Samples drawn per class per training step.
    steps_per_epoch : int
        Discriminator gradient steps taken per animation frame.
    """
    for step in range(steps_per_epoch):
        # fresh batch each step: prior samples (class 0) vs.
        # approximate-posterior samples (class 1)
        w_sample_prior = prior.rvs(size=batch_size)
        eps = np.random.randn(batch_size, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(batch_size), np.ones(batch_size)))
        metrics = discriminator.train_on_batch(inputs, targets)
    # redraw from scratch each frame (blitting is disabled)
    ax.cla()
    w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
    w_grid_ratio = w_grid_ratio.reshape(300, 300)
    ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
    ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
    # metrics come from the *last* training step of this frame
    train_info = dict(zip(discriminator.metrics_names, metrics))
    train_info['epoch'] = epoch_num
    props = dict(boxstyle='round', facecolor='w', alpha=0.5)
    ax.text(0.05, 0.05,
            ('epoch: {epoch:2d}\n'
             'accuracy: {binary_accuracy:.2f}\n'
             'loss: {loss:.2f}').format(**train_info),
            transform=ax.transAxes, bbox=props)
    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')
    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    return ax
# Render 60 frames (epochs) as an inline HTML5 video; each frame runs
# `train_animate` once on the shared `fig`/`ax`.
FuncAnimation(fig, train_animate, frames=60,
              interval=200,  # 5 fps
              blit=False)
# Re-score the most recently drawn sample batch after the animated
# training run, and replot the learned ratio surface.
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap='magma')
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
report = dict(zip(discriminator.metrics_names, metrics))
box_style = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
        ('accuracy: {binary_accuracy:.2f}\n'
         'loss: {loss:.2f}').format(**report),
        transform=ax.transAxes, bbox=box_style)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
def set_trainable(model, trainable):
    """Set the ``trainable`` flag on *model* and, recursively, on every
    nested layer (a pre-order walk of the layer tree)."""
    model.trainable = trainable
    # only Model instances expose sub-layers to descend into
    if isinstance(model, Model):
        for child in model.layers:
            set_trainable(child, trainable)
# Reproduce the NumPy likelihood surfaces with Keras backend ops.
w_tensor = K.constant(w_grid)
x_tensor = K.constant(X)
y_pred = K.sigmoid(K.dot(w_tensor, K.transpose(x_tensor)))
y_pred
# broadcast the three labels across the whole grid
y_true = K.ones((300, 300, 1))*K.constant(y)
y_true
# negative binary cross-entropy == per-datum log-likelihood
llhs_keras = - K.binary_crossentropy(
    y_pred,
    y_true,
    from_logits=False)
sess = K.get_session()
# agreement with the scipy/NumPy computation above
np.allclose(np.sum(llhs, axis=-1),
            sess.run(K.sum(llhs_keras, axis=-1)))
# Contour plot of the Keras-computed total log-likelihood surface.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(K.sum(llhs_keras, axis=-1)),
            cmap=plt.cm.magma)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
def make_elbo(ratio_estimator):
    """Build a Keras loss computing the ELBO, with the intractable KL term
    replaced by the discriminator's log density-ratio estimate.

    Freezes the ratio estimator so that gradients of the ELBO flow only
    into the inference network's parameters.
    """
    set_trainable(ratio_estimator, False)
    def elbo(y_true, w_sample):
        # log r(w) ~ log q(w)/p(w): a single-sample estimate of the KL term
        kl_estimate = ratio_estimator(w_sample)
        # logits of the Bernoulli likelihood: w . x_n for each datum
        y_pred = K.dot(w_sample, K.transpose(K.constant(X)))
        # NOTE(review): the (y_pred, y_true) argument order matches the
        # older Keras backend signature (and the check earlier in this
        # post) — confirm for the Keras version in use.
        log_likelihood = - K.binary_crossentropy(y_pred, y_true,
                                                 from_logits=True)
        return K.mean(log_likelihood-kl_estimate, axis=-1)
    return elbo
elbo = make_elbo(ratio_estimator)
# The ELBO surface over the weight grid, using the current ratio estimate.
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, sess.run(elbo(y_true, K.constant(w_grid))),
            cmap=plt.cm.magma)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
# Minimising the negative ELBO is equivalent to maximising the ELBO.
def inference_loss(y_true, w_sample):
    return -make_elbo(ratio_estimator)(y_true, w_sample)
inference.compile(loss=inference_loss,
                  optimizer=Adam(lr=LEARNING_RATE))
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
# Tile the three labels across the batch: shape (BATCH_SIZE, 3).
y_true = K.repeat_elements(K.expand_dims(K.constant(y), axis=0),
                           axis=0, rep=BATCH_SIZE)
y_true
# These two evaluations should agree up to sign (loss == negative ELBO).
sess.run(K.mean(elbo(y_true, inference(K.constant(eps))), axis=-1))
inference.evaluate(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
# Alternating optimisation: one inference-network (generator) step, then
# 3*50 discriminator steps per outer epoch.
for epoch in tnrange(200, desc='epoch'):
    # NOTE(review): Keras captures layer trainability at compile time, so
    # these set_trainable toggles only take effect if the models are
    # recompiled — confirm this behaves as intended for the Keras version
    # in use.
    set_trainable(ratio_estimator, False)
    for _ in tnrange(1, desc='generator'):
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        metrics_inference = inference.train_on_batch(eps, np.tile(y, reps=(BATCH_SIZE, 1)))
    set_trainable(discriminator, True)
    for _ in tnrange(3*50, desc='discriminator'):
        # fresh batch each step: prior samples (0) vs. posterior samples (1)
        w_sample_prior = prior.rvs(size=BATCH_SIZE)
        eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
        w_sample_posterior = inference.predict(eps)
        inputs = np.vstack((w_sample_prior, w_sample_posterior))
        targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
        metrics_discrim = discriminator.train_on_batch(inputs, targets)
# Final snapshot: discriminator performance and ratio surface after the
# full alternating-training run.
inputs = np.vstack((w_sample_prior, w_sample_posterior))
targets = np.hstack((np.zeros(BATCH_SIZE), np.ones(BATCH_SIZE)))
metrics = discriminator.evaluate(inputs, targets)
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)
fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.gray)
ax.scatter(*inputs.T, c=targets, s=4.**2, alpha=.8, cmap='coolwarm')
summary = dict(zip(discriminator.metrics_names, metrics))
box_style = dict(boxstyle='round', facecolor='w', alpha=0.5)
ax.text(0.05, 0.05,
        ('accuracy: {binary_accuracy:.2f}\n'
         'loss: {loss:.2f}').format(**summary),
        transform=ax.transAxes, bbox=box_style)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))
plt.show()
This is your PhD, and it's ending one week at a time.
| 1 | 2 | 3 | 4 | |
| 1 | ||||
| 5 | ||||
| 10 | ||||
| 15 | ||||
| 20 | ||||
| 25 | ||||
| x | ||||
| x | ||||
| x | ||||
| 30 | x | |||
| x | ||||
| x | ||||
| x | ||||
| x | ||||
| 35 | x | |||
| x | ||||
| x | ||||
| x | ||||
| x | ||||
| 40 | x | |||
| x | ||||
| x | ||||
| x | ||||
| x | ||||
| 45 | x | |||
| x | ||||
| x | ||||
| x | ||||
| x | ||||
| 50 | x | |||
| x | ||||
| x |